{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GraphRAG Quickstart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "Install 3rd party packages, not part of the Python Standard Library, to run the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install devtools requests tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import json\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import requests\n",
    "from devtools import pprint\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (REQUIRED) User Configuration\n",
    "Set the API subscription key, API base endpoint, and some file directory names that will be referenced later in the notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### API subscription key\n",
    "\n",
    "APIM supports multiple forms of authentication and access control (e.g. managed identity). For this notebook demonstration, we will use a **[subscription key](https://learn.microsoft.com/en-us/azure/api-management/api-management-subscriptions)**. To locate this key, visit the Azure Portal. The subscription key can be found under `<my_resource_group> --> <API Management service> --> <APIs> --> <Subscriptions> --> <Built-in all-access subscription> Primary Key`. For multiple API users, individual subscription keys can be generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ocp_apim_subscription_key = getpass.getpass(\n",
    "    \"Enter the subscription key to the GraphRag APIM:\"\n",
    ")\n",
    "\n",
    "\"\"\"\n",
    "\"Ocp-Apim-Subscription-Key\": \n",
    "    This is a custom HTTP header used by Azure API Management service (APIM) to \n",
    "    authenticate API requests. The value for this key should be set to the subscription \n",
    "    key provided by the Azure APIM instance in your GraphRAG resource group.\n",
    "\"\"\"\n",
    "headers = {\"Ocp-Apim-Subscription-Key\": ocp_apim_subscription_key}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup directories and API endpoint\n",
    "\n",
    "For demonstration purposes, please use the provided `get-wiki-articles.py` script to download a small set of wikipedia articles or provide your own data (graphrag requires txt files to be utf-8 encoded)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "These parameters must be defined by the notebook user:\n",
    "\n",
    "- file_directory: a local directory of text files. The file structure should be flat,\n",
    "                  with no nested directories. (i.e. file_directory/file1.txt, file_directory/file2.txt, etc.)\n",
    "- storage_name:   a unique name to identify a blob storage container in Azure where files\n",
    "                  from `file_directory` will be uploaded.\n",
    "- index_name:     a unique name to identify a single graphrag knowledge graph index.\n",
    "                  Note: Multiple indexes may be created from the same `storage_name` blob storage container.\n",
    "- endpoint:       the base/endpoint URL for the GraphRAG API (this is the Gateway URL found in the APIM resource).\n",
    "\"\"\"\n",
    "\n",
    "file_directory = \"\"\n",
    "storage_name = \"\"\n",
    "index_name = \"\"\n",
    "endpoint = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert (\n",
    "    file_directory != \"\" and storage_name != \"\" and index_name != \"\" and endpoint != \"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload Files\n",
    "\n",
    "For a demonstration of how to index data in graphrag, we first need to ingest a few files into graphrag. **Multiple filetypes are now supported via the [MarkItDown](https://github.com/microsoft/markitdown) library.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def upload_files(\n",
    "    file_directory: str,\n",
    "    container_name: str,\n",
    "    batch_size: int = 100,\n",
    "    overwrite: bool = True,\n",
    "    max_retries: int = 5,\n",
    ") -> requests.Response | list[Path]:\n",
    "    \"\"\"\n",
    "    Upload files to a blob storage container.\n",
    "\n",
    "    Args:\n",
    "    file_directory - a local directory of .txt files to upload. All files must be in utf-8 encoding.\n",
    "    container_name - a unique name for the Azure storage container.\n",
    "    batch_size - the number of files to upload in a single batch.\n",
    "    overwrite - whether or not to overwrite files if they already exist in the storage container.\n",
    "    max_retries - the maximum number of times to retry uploading a batch of files if the API is busy.\n",
    "\n",
    "    NOTE: Uploading files may sometimes fail if the blob container was recently deleted\n",
    "    (i.e. a few seconds before. The solution \"in practice\" is to sleep a few seconds and try again.\n",
    "    \"\"\"\n",
    "    url = endpoint + \"/data\"\n",
    "\n",
    "    def upload_batch(\n",
    "        files: list, container_name: str, overwrite: bool, max_retries: int\n",
    "    ) -> requests.Response:\n",
    "        for _ in range(max_retries):\n",
    "            response = requests.post(\n",
    "                url=url,\n",
    "                files=files,\n",
    "                params={\"container_name\": container_name, \"overwrite\": overwrite},\n",
    "                headers=headers,\n",
    "            )\n",
    "            # API may be busy, retry\n",
    "            if response.status_code == 500:\n",
    "                print(\"API busy. Sleeping and will try again.\")\n",
    "                time.sleep(10)\n",
    "                continue\n",
    "            return response\n",
    "        return response\n",
    "\n",
    "    batch_files = []\n",
    "    filepaths = list(Path(file_directory).iterdir())\n",
    "    for file in tqdm(filepaths):\n",
    "        # validate that file is a file, has acceptable file type, has a .txt extension, and has utf-8 encoding\n",
    "        if (not file.is_file()):\n",
    "            print(f\"Skipping invalid file: {file}\")\n",
    "            continue\n",
    "        batch_files.append(\n",
    "            (\"files\", open(file=file, mode=\"rb\"))\n",
    "        )\n",
    "        # upload batch of files\n",
    "        if len(batch_files) == batch_size:\n",
    "            response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
    "            # if response is not ok, return early\n",
    "            if not response.ok:\n",
    "                return response\n",
    "            batch_files.clear()\n",
    "    # upload last batch of remaining files\n",
    "    if len(batch_files) > 0:\n",
    "        response = upload_batch(batch_files, container_name, overwrite, max_retries)\n",
    "    return response\n",
    "\n",
    "\n",
    "response = upload_files(\n",
    "    file_directory=file_directory,\n",
    "    container_name=storage_name,\n",
    "    batch_size=100,\n",
    "    overwrite=True,\n",
    ")\n",
    "if not response.ok:\n",
    "    print(response.text)\n",
    "else:\n",
    "    print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build an Index\n",
    "\n",
    "After data files have been uploaded, we can construct a knowledge graph by building a search index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_index(\n",
    "    storage_name: str,\n",
    "    index_name: str,\n",
    ") -> requests.Response:\n",
    "    \"\"\"Create a search index.\n",
    "    This function kicks off a job that builds a knowledge graph index from files located in a blob storage container.\n",
    "    \"\"\"\n",
    "    url = endpoint + \"/index\"\n",
    "    return requests.post(\n",
    "        url,\n",
    "        params={\n",
    "            \"index_container_name\": index_name,\n",
    "            \"storage_container_name\": storage_name,\n",
    "        },\n",
    "        headers=headers,\n",
    "    )\n",
    "\n",
    "\n",
    "response = build_index(storage_name=storage_name, index_name=index_name)\n",
    "print(response)\n",
    "if response.ok:\n",
    "    print(response.text)\n",
    "else:\n",
    "    print(f\"Failed to submit job.\\nStatus: {response.text}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check status of an indexing job\n",
    "\n",
    "Please wait for your index to reach 100 percent completion before continuing on to the next section (running queries). You may rerun the next cell multiple times to monitor status. Note: the indexing speed of graphrag is directly correlated to the TPM quota of the Azure OpenAI model you are using."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_status(index_name: str) -> requests.Response:\n",
    "    url = endpoint + f\"/index/status/{index_name}\"\n",
    "    return requests.get(url, headers=headers)\n",
    "\n",
    "\n",
    "response = index_status(index_name)\n",
    "pprint(response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query\n",
    "\n",
    "Once an indexing job is complete, the knowledge graph is ready to query. Two types of queries (global and local) are currently supported. We encourage you to try both and experience the difference in responses. Note that query response time is also correlated to the TPM quota of the Azure OpenAI model you are using."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a helper function to parse out the result from a query response\n",
    "def parse_query_response(\n",
    "    response: requests.Response, return_context_data: bool = False\n",
    ") -> requests.Response | dict[list[dict]]:\n",
    "    \"\"\"\n",
    "    Print response['result'] value and return context data.\n",
    "    \"\"\"\n",
    "    if response.ok:\n",
    "        print(json.loads(response.text)[\"result\"])\n",
    "        if return_context_data:\n",
    "            return json.loads(response.text)[\"context_data\"]\n",
    "        return response\n",
    "    else:\n",
    "        print(response.reason)\n",
    "        print(response.content)\n",
    "        return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global Query \n",
    "\n",
    "Global queries are resource-intensive, but provide good responses to questions that require an understanding of the dataset as a whole."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def global_search(\n",
    "    index_name: str | list[str], query: str, community_level: int\n",
    ") -> requests.Response:\n",
    "    \"\"\"Run a global query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
    "    url = endpoint + \"/query/global\"\n",
    "    # optional parameter: community level to query the graph at (default for global query = 1)\n",
    "    request = {\n",
    "        \"index_name\": index_name,\n",
    "        \"query\": query,\n",
    "        \"community_level\": community_level,\n",
    "    }\n",
    "    return requests.post(url, json=request, headers=headers)\n",
    "\n",
    "\n",
    "# perform a global query\n",
    "global_response = global_search(\n",
    "    index_name=index_name,\n",
    "    query=\"Summarize the main topics found in this data\",\n",
    "    community_level=1,\n",
    ")\n",
    "global_response_data = parse_query_response(global_response, return_context_data=True)\n",
    "global_response_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Local Query\n",
    "\n",
    "Local search queries are best suited for narrow-focused questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def local_search(\n",
    "    index_name: str | list[str], query: str, community_level: int\n",
    ") -> requests.Response:\n",
    "    \"\"\"Run a local query over the knowledge graph(s) associated with one or more indexes\"\"\"\n",
    "    url = endpoint + \"/query/local\"\n",
    "    # optional parameter: community level to query the graph at (default for local query = 2)\n",
    "    request = {\n",
    "        \"index_name\": index_name,\n",
    "        \"query\": query,\n",
    "        \"community_level\": community_level,\n",
    "    }\n",
    "    return requests.post(url, json=request, headers=headers)\n",
    "\n",
    "\n",
    "# perform a local query\n",
    "local_response = local_search(\n",
    "    index_name=index_name,\n",
    "    query=\"Summarize the main topics found in this data\",\n",
    "    community_level=2,\n",
    ")\n",
    "local_response_data = parse_query_response(local_response, return_context_data=True)\n",
    "local_response_data"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
